import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


class LDA(object):
    """Two-class Fisher linear discriminant analysis (LDA).

    Samples whose label equals ``1`` form the first class (``x1``) and
    samples whose label equals ``0`` form the second class (``x2``).
    ``fit`` computes the projection direction
    ``w = pinv(Sw) @ (mean1 - mean2)`` where ``Sw`` is the within-class
    scatter matrix; ``predict`` assigns a sample to the class whose projected
    mean is closest to the sample's projection.
    """

    def __init__(self, _X, _Y, classnumber=2):
        """Store the training data and split it by label.

        Args:
            _X: 2-D array-like of shape (n_samples, n_features).
            _Y: array-like holding n_samples labels (1 or 0); a column
                vector is accepted and flattened for the class split.
            classnumber: number of classes; stored but not used by the
                two-class algorithm implemented here.

        Raises:
            ValueError: if ``_X`` is not 2-D or the number of labels does
                not match the number of samples.
        """
        self.X = np.asarray(_X)
        self.Y = _Y
        if self.X.ndim != 2:
            raise ValueError("_X must be a 2-D array of shape (n_samples, n_features)")
        self.data_size = [self.X.shape[0], self.X.shape[1]]

        # Flatten the labels so that (n,) and (n, 1) inputs build the same mask.
        labels = np.asarray(_Y).reshape(-1)
        if labels.shape[0] != self.data_size[0]:
            # Explicit exception instead of `assert`, which vanishes under -O.
            raise ValueError("_X and _Y must contain the same number of samples")

        # The direction has one entry per feature; same (1, n_features)
        # layout that fit() produces.
        self.w = np.zeros((1, self.data_size[1]))
        self.classnumber = classnumber
        self.x1 = self.X[labels == 1]
        self.x2 = self.X[labels == 0]
        self.xmean = 0

    def fit(self, isdraw=True):
        """Compute the projection direction and optionally plot the result.

        Args:
            isdraw: when True, draw the samples, the projection line and the
                projected points with matplotlib (uses the first two
                features only).

        Returns:
            self, so that calls can be chained.

        Raises:
            ValueError: if either class has no samples.
        """
        if self.x1.shape[0] == 0 or self.x2.shape[0] == 0:
            raise ValueError("both classes (labels 1 and 0) need at least one sample")

        self.meanx1 = self.x1.mean(0, keepdims=True)
        self.meanx2 = self.x2.mean(0, keepdims=True)

        # Within-class scatter matrix: sum of the per-class scatter matrices.
        centered1 = self.x1 - self.meanx1
        centered2 = self.x2 - self.meanx2
        Sw = np.matmul(centered1.T, centered1) + np.matmul(centered2.T, centered2)

        # The pseudo-inverse is SVD based like the explicit
        # v * diag(s)^-1 * u^T product, but it drops zero singular values
        # instead of failing on a singular Sw (e.g. a constant feature).
        self.w = np.dot(np.linalg.pinv(Sw), (self.meanx1 - self.meanx2).T).reshape(1, -1)

        if isdraw:
            self._draw()
        return self

    def _draw(self):
        """Plot samples, the projection line and the projected points."""
        _, ax = plt.subplots()
        ax.spines['right'].set_color('none')
        ax.spines['top'].set_color('none')
        ax.spines['left'].set_position(('data', 0))
        ax.spines['bottom'].set_position(('data', 0))

        plt.scatter(self.x1[:, 0], self.x1[:, 1], c='k', marker='o', label='good')
        plt.scatter(self.x2[:, 0], self.x2[:, 1], c='r', marker='x', label='bad')

        plt.xlabel('密度', labelpad=1)
        plt.ylabel('含糖量')
        plt.legend(loc='upper right')

        # Line through the origin along the direction w.
        x_tmp = np.linspace(-0.05, 0.15)
        y_tmp = x_tmp * self.w[0, 1] / self.w[0, 0]
        plt.plot(x_tmp, y_tmp, '#808080', linewidth=1)

        # Orthogonal projector onto the unit direction wu.
        wu = self.w / np.linalg.norm(self.w)
        projector = np.dot(wu.T, wu)

        # Projections of the samples; each class keeps the colour of its
        # scatter points (class 1 black, class 0 red).
        X0_project = np.dot(self.x1, projector)
        plt.scatter(X0_project[:, 0], X0_project[:, 1], c='k', s=15)
        for point, shadow in zip(self.x1, X0_project):
            plt.plot([point[0], shadow[0]], [point[1], shadow[1]], '--k', linewidth=1)

        X1_project = np.dot(self.x2, projector)
        plt.scatter(X1_project[:, 0], X1_project[:, 1], c='r', s=15)
        for point, shadow in zip(self.x2, X1_project):
            plt.plot([point[0], shadow[0]], [point[1], shadow[1]], '--r', linewidth=1)

        # Projections of the class means (grey for class 1, orange-red for
        # class 0, matching the sample colours above).
        u0_project = np.dot(self.meanx1, projector)
        plt.scatter(u0_project[:, 0], u0_project[:, 1], c='#696969', s=60)
        u1_project = np.dot(self.meanx2, projector)
        plt.scatter(u1_project[:, 0], u1_project[:, 1], c='#FF4500', s=60)

        # Scalars rather than 1-element arrays for the annotation coordinates.
        # "u0" labels the projected mean of x1 and "u1" the one of x2.
        u0_x, u0_y = float(u0_project[0, 0]), float(u0_project[0, 1])
        u1_x, u1_y = float(u1_project[0, 0]), float(u1_project[0, 1])

        ax.annotate(r'u0 投影点',
                    xy=(u0_x, u0_y),
                    xytext=(u0_x - 0.2, u0_y - 0.1),
                    size=13,
                    va="center", ha="left",
                    arrowprops=dict(arrowstyle="->",
                                    color="k",
                                    )
                    )

        ax.annotate(r'u1 投影点',
                    xy=(u1_x, u1_y),
                    xytext=(u1_x - 0.1, u1_y + 0.1),
                    size=13,
                    va="center", ha="left",
                    arrowprops=dict(arrowstyle="->",
                                    color="k",
                                    )
                    )
        plt.axis("equal")  # same unit length on both axes
        plt.show()

    def predict(self, X):
        """Classify samples by the nearest projected class mean.

        Args:
            X: array-like of shape (n_samples, n_features).

        Returns:
            Integer array with 1 where the projection is strictly closer to
            the projected mean of ``x1`` (label 1) and 0 otherwise.
        """
        project = np.dot(X, self.w.T)
        wu0 = np.dot(self.w, self.meanx1.T)
        wu1 = np.dot(self.w, self.meanx2.T)
        dist_to_class1 = np.sum(np.abs(project - wu0.T), axis=1)
        dist_to_class0 = np.sum(np.abs(project - wu1.T), axis=1)
        return (dist_to_class0 > dist_to_class1).astype(int)

    def cast(self, X):
        """Return the projection of ``X`` onto the direction ``w``."""
        project = np.dot(X, self.w.T)
        return project
        
      


if __name__ == '__main__':
    # Load the CSV as a plain array: the first two columns are used as float
    # features and the third column as the labels.
    table = pd.read_csv("watermelon.csv").values
    features = table[:, :2].astype(float)

    # Report the number of samples.
    print(features.shape[0])

    targets = table[:, 2]

    # Fit the model (plots by default), then compare predictions to labels.
    model = LDA(features, targets)
    model.fit()
    print(model.predict(features))
    print(targets)